'''
This script runs an ensemble on our North Carolina JSON, recording data on the Republican seat count, mean-median, partisan Gini, Republican vote share, efficiency gap, and the number of cut edges for 100,000 districting plans. We first import the necessary Python libraries, including the GerryChain library.
'''

### Imports ###
import os
from functools import partial
import json
import csv
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import sys

from gerrychain import (
    Election,
    Graph,
    MarkovChain,
    Partition,
    accept,
    constraints,
    updaters,
)

from gerrychain.metrics import efficiency_gap, mean_median, partisan_gini
from gerrychain.proposals import recom
from gerrychain.updaters import cut_edges
from gerrychain.tree import recursive_tree_part

# State specifications
state = 'northcarolina'   # used in the graph JSON path and in output file names
state_abbr = "NC"         # used in the output folder name
num_dist = 13             # number of districts; divides total population for the ReCom target

# One row per election: (name, Democratic vote column, Republican vote column).
# NOTE(review): SEN14 and SEN16 list the "US_1" column first, unlike SEN08/SEN10 --
# presumably this matches the column names in the source data; confirm.
_ELECTION_TABLE = [
    ("GOV08", "EL08G_GV_D", "EL08G_GV_R"),
    ("SEN08", "EL08G_USS_", "EL08G_US_1"),
    ("SEN10", "EL10G_USS_", "EL10G_US_1"),
    ("GOV12", "EL12G_GV_D", "EL12G_GV_R"),
    ("SEN14", "EL14G_US_1", "EL14G_USS_"),
    ("PRES12", "EL12G_PR_D", "EL12G_PR_R"),
    ("PRES16", "EL16G_PR_D", "EL16G_PR_R"),
    ("SEN16", "EL16G_US_1", "EL16G_USS_"),
    ("GOV16", "EL16G_GV_D", "EL16G_GV_R"),
]
# Parallel lists derived from the table above; each column pair is [DEM, REP].
election_names = [row[0] for row in _ELECTION_TABLE]
election_columns = [[row[1], row[2]] for row in _ELECTION_TABLE]

# Chain specifications
pop_bound = 0.01          # population tolerance given to the proposal and the population constraint
num_steps = 100000        # total number of steps in the Markov chain
dump_interval = 10000     # a CSV of results is written every `dump_interval` steps

# To test the Markov chain, try with fewer steps (otherwise this script will take hours):
# num_steps     = 100
# dump_interval = 10

# Output folder for this state's run. The trailing slash matters: every later
# output path is built by plain string concatenation onto `newdir`.
newdir = f"../outputs/{state_abbr}output/"
os.makedirs(newdir, exist_ok=True)

# Drop a small marker file into the freshly created folder.
with open(newdir + "init.txt", "w") as marker_file:
    marker_file.write("Created Folder")

# Load the dual graph for the state from its JSON file.
graph_path = f"../jsons/{state}.json"
graph = Graph.from_json(graph_path)

# Total population: the sum of the "TOTPOP" attribute over every node.
pop_count = sum(graph.nodes[node]["TOTPOP"] for node in graph.nodes)

pop = pop_count

# Updaters that are independent of any election: district population and cut edges.
base_updaters = {
    "population": updaters.Tally("TOTPOP", alias="population"),
    "cut_edges": cut_edges,
}

# One Election updater per election, keyed by election name. Each entry of
# `election_columns` is unpacked as (Democratic column, Republican column).
election_updaters = {
    name: Election(name, {"republican": rep_col, "democratic": dem_col})
    for name, (dem_col, rep_col) in zip(election_names, election_columns)
}

my_updaters = {**base_updaters, **election_updaters}

# Starting plan for the chain, using "newplan" as the assignment.
initial_partition = Partition(
    graph,
    assignment="newplan",
    updaters=my_updaters,
)

# ReCom proposal: target population is the total population split evenly
# across the districts, with `pop_bound` as the allowed deviation.
ideal_population = pop / num_dist
proposal = partial(
    recom,
    pop_col="TOTPOP",
    pop_target=ideal_population,
    epsilon=pop_bound,
    node_repeats=3,
)

# Upper bound on cut edges: twice the count in the initial plan.
# NOTE(review): this bound is defined but not included in the chain's
# constraint list below -- confirm whether that omission is intentional.
max_cut_edges = 2 * len(initial_partition["cut_edges"])
compactness_bound = constraints.UpperBound(
    lambda p: len(p["cut_edges"]),
    max_cut_edges,
)

# The only constraint applied to the chain is population balance.
population_constraint = constraints.within_percent_of_ideal_population(
    initial_partition, pop_bound
)

chain = MarkovChain(
    proposal=proposal,
    constraints=[population_constraint],
    accept=accept.always_accept,
    initial_state=initial_partition,
    total_steps=num_steps,
)

# Per-election list of result rows, appended to on every chain step.
data = {name: [] for name in election_names}
# Step counter used to decide when results are dumped to disk.
t = 0

# Record the run settings next to the results, one specs file per election.
for election in election_names:
    specs_path = newdir + state + election + "_data_specs.txt"
    with open(specs_path, "w") as specs_file:
        specs_file.writelines((
            "state: \n",
            state,
            "\nelection: \n",
            election,
            "\npopulation bound: \n",
            str(pop_bound),
            "\nnumber of steps: \n",
            str(num_steps),
            "\ndumping interval: \n",
            str(dump_interval),
        ))

def _write_election_csv(election, step_count, rows):
    """Write `rows` for one election to `<newdir><state><election>_data<step_count>.csv`.

    The file starts with a header row: seats, mean-median, partisan Gini,
    Republican vote shares, efficiency gap, cut edges.
    """
    csv_path = newdir + state + election + "_data" + str(step_count) + ".csv"
    # newline="" lets the csv writer control line endings itself, as the csv
    # module documentation requires for files passed to csv.writer.
    with open(csv_path, "w", newline="") as csv_file:
        writer = csv.writer(csv_file, lineterminator="\n")
        writer.writerow(['seats', 'mm', 'pg', 'vs', 'eg', 'ce'])
        writer.writerows(rows)


# Run the chain, scoring every plan under every election.
t = 0  # stays 0 if the chain yields no states
for t, step in enumerate(chain, start=1):
    num_cut_edges = len(step['cut_edges'])
    for election in election_names:
        results = step[election]
        # Computed once and reused for both the data row and the example filter.
        rep_seats = results.wins("republican")
        gini = partisan_gini(results)
        data[election].append([rep_seats,
                               mean_median(results),
                               gini,
                               results.percents("republican"),
                               efficiency_gap(results),
                               num_cut_edges])
        # Save plans with 11 Republican seats and a partisan Gini below 0.0116
        # under SEN16 as examples (assignment and score in parallel files).
        if election == "SEN16" and rep_seats == 11 and gini < 0.0116:
            with open(newdir + "PG_example_assignments.txt", "a") as f:
                f.write(f"{str(dict(step.assignment))}\n")
            with open(newdir + "PG_example_scores.txt", "a") as f:
                f.write(f"{gini}\n")
    # Every `dump_interval` steps, write the most recent batch of rows to disk.
    if t % dump_interval == 0:
        for election in election_names:
            _write_election_csv(election, t, data[election][-dump_interval:])

# If the number of steps is not a multiple of `dump_interval`, the trailing
# steps would otherwise never reach disk; write them as a final, shorter batch.
leftover_steps = t % dump_interval
if leftover_steps:
    for election in election_names:
        _write_election_csv(election, t, data[election][-leftover_steps:])
